SHAP Meets Tensor Networks 논문 리뷰
| Venue | Link |
|---|---|
| NeurIPS 2025 poster | OpenReview, arXiv |
Marzouk, Bassan, Katz [3]의 SHAP Meets Tensor Networks는 exact SHAP 계산을 Tensor Network 관점에서 다시 보는 논문이다. 논문은 제목처럼 Tensor Networks (TNs) 라는 넓은 표현 체계를 다루지만, 가장 강한 tractability 결과는 그중 Tensor Train (TT) 구조에서 나온다. 따라서 이 글의 핵심 질문은 “왜 TT 구조가 exact SHAP 계산을 쉽게 만드는가?”이다.
1 Introduction
SHAP은 현재 XAI에서 가장 널리 쓰이는 local explanation 방법 중 하나이다 [2]. 기본 아이디어는 cooperative game theory의 Shapley value를 feature attribution에 적용하는 것이다 [5]. 모델의 특정 prediction을 설명할 때, 각 feature가 여러 feature subset이라는 context 안에서 평균적으로 얼마나 prediction을 바꾸는지 계산한다. 이 방식은 공정한 attribution이라는 장점이 있지만, 모든 feature subset을 다루어야 하므로 일반적으로 계산이 매우 어렵다.
기존 연구는 이 계산 난점을 모델 구조별로 나누어 다루었다. Decision tree나 tree ensemble처럼 구조가 비교적 명확한 모델에서는 exact SHAP을 polynomial time에 계산할 수 있다 [1]. TreeSHAP 계열의 핵심은 tree path 구조를 이용해 exponential subset enumeration을 피하는 것이다. 하지만 neural network처럼 표현력이 큰 black-box model로 가면 exact SHAP 계산은 NP-hard가 된다. 이 지점이 논문의 출발점이다. 설명이 가장 절실한 모델일수록, exact explanation은 더 계산하기 어려워지는 역설이 생긴다.
논문은 이 간극을 Tensor Networks로 공략한다. Tensor Network는 고차원 tensor를 작은 tensor들의 graph로 분해해 표현하는 방식이다. TT-format은 그중에서도 1차원 chain 형태의 Tensor Network이다 [4]. 원래 Tensor Network는 quantum physics에서 복잡한 many-body state를 효율적으로 표현하기 위해 발전했지만, 최근에는 machine learning에서도 classification, regression, generative modeling, neural network compression 등에 사용된다. 즉 Tensor Network는 단순히 수치선형대수의 저장공간 절약 기법이 아니라, 표현력 있는 model class로 볼 수 있다.
여기서 논문이 XAI 관점에서 Tensor Network를 중요하게 보는 이유가 나온다. 좋은 explanation 대상 모델은 두 성질을 동시에 가져야 한다. 첫째, 현실의 복잡한 함수를 표현할 수 있을 만큼 expressive해야 한다. 둘째, 설명을 계산할 수 있을 만큼 구조적이어야 한다. 너무 단순한 모델은 설명하기 쉽지만 성능이 부족할 수 있고, 너무 자유로운 black-box model은 성능은 좋지만 exact explanation이 어렵다. 논문은 Tensor Network가 이 두 성질 사이의 중간 지점을 제공한다고 본다. 특히 TT처럼 구조가 강한 Tensor Network는 neural network 계열만큼 표현력이 있는 방향으로 확장될 수 있으면서도, contraction과 factorization 구조 덕분에 SHAP 계산을 이론적으로 분석할 수 있다.
이 관점에서 논문의 핵심 메시지는 다음처럼 읽을 수 있다.
XAI에서 Tensor Train이 중요한 이유는 “표현력 있는 모델을 구조적으로 쪼개어, exact SHAP 계산을 병렬화 가능한 tensor contraction 문제로 바꿀 수 있기 때문”이다.
논문은 contribution을 네 갈래로 제시한다. 첫 번째는 일반 Tensor Network에 대한 exact SHAP 계산 framework이다. 임의 구조의 TN에 대해 Marginal SHAP을 계산하는 문제를 tensor contraction으로 정식화한다. 다만 일반 TN은 여전히 어렵다. 논문은 일반 TN에 대한 Marginal SHAP 계산이 #P-hard임도 함께 보인다. 여기서 #P-hard는 “가능한 경우의 수를 세는 counting problem” 계열의 어려움을 뜻한다. 단순히 답이 yes/no인지 묻는 NP-hard 문제보다, 모든 가능한 witness의 수나 합을 계산해야 하는 상황에서 자주 등장한다.
두 번째 contribution은 Tensor Train에 대한 tractability 결과이다. TN 구조를 TT로 제한하면 SHAP 계산을 병렬 계산에서 poly-logarithmic time에 수행할 수 있다는 것이 논문의 가장 중요한 결과다. 복잡도 이론 용어로는 이 문제가 NC^2에 들어간다고 말한다. NC는 많은 processor를 병렬로 쓸 수 있을 때 poly-logarithmic depth의 circuit으로 풀리는 문제 class이다. 따라서 이 결과는 “sequential time이 작다”기보다, “TT 구조에서는 SHAP 계산이 병렬화하기 좋은 형태로 정리된다”는 의미가 강하다.
세 번째 contribution은 이 TT 결과를 다른 모델로 옮기는 것이다. 논문은 decision tree, tree ensemble, linear model, linear RNN을 TT 표현으로 환원할 수 있음을 이용한다. 그 결과 이 모델들에 대한 Marginal SHAP 계산도 TT distribution 아래에서 NC^2에 들어간다고 정리한다. 이 부분은 TreeSHAP 계열 결과를 부정하는 것이 아니라, tree-specific algorithm을 더 큰 tensor representation 관점에서 다시 일반화하는 쪽에 가깝다.
네 번째 contribution은 binarized neural network (BNN)에 대한 fine-grained complexity 분석이다. BNN은 activation이나 weight가 이진값으로 제한된 neural network이다. 논문은 BNN을 TN으로 표현하는 reduction을 이용해 SHAP 계산의 병목이 어디에 있는지 본다. 결론은 흥미롭다. Depth를 고정해도 hard함은 남지만, width를 고정하면 tractable한 영역이 열린다. 즉 이 모델 class에서는 SHAP 계산의 주요 병목이 depth보다 width 쪽에 있다는 메시지를 준다.
Introduction의 전체 흐름을 요약하면 다음과 같다. SHAP은 중요하지만 exact 계산이 어렵다. 기존 exact algorithm은 tree나 linear model 같은 제한된 구조에서 잘 동작한다. Tensor Network는 neural network에 가까운 표현력과 수학적으로 분석 가능한 구조를 동시에 제공한다. 그중 Tensor Train은 chain 구조 덕분에 exact SHAP을 병렬적으로 tractable하게 만들 수 있다. 그래서 이 논문은 “표현력 있는 모델에서 exact explanation이 어디까지 가능한가?”라는 질문에 Tensor Network라는 답을 제시한다.
2 Preliminaries
논문은 먼저 SHAP 계산을 Marginal SHAP 관점에서 정리한다. Feature 집합을 N=\{1,\ldots,n\}이라고 하자. 모델을 M이라고 하고, 설명하려는 입력을 x라고 하자. Feature i의 Shapley value는 feature subset S \subseteq N \setminus \{i\}에 feature i를 추가했을 때 marginal value가 얼마나 바뀌는지 평균낸 값이다.
\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!(n-|S|-1)!}{n!} \left( v(S \cup \{i\}) - v(S) \right)
여기서 핵심은 value function v(S)이다. Marginal SHAP에서는 모르는 feature들을 background distribution에서 marginalize한다. 즉 S에 속한 feature는 설명 대상 입력 x의 값을 고정하고, S 밖 feature는 distribution P에 따라 평균낸다.
v(S) = \mathbb{E}_{z \sim P} \left[ M(x_S, z_{\bar S}) \right]
이 정의는 직관적이다. “feature subset S만 알고 있다면 모델 output의 기대값은 얼마인가?”를 묻는다. 하지만 이 정의를 그대로 계산하려면 S를 모두 순회해야 하고, 각 S마다 expectation도 계산해야 한다. 논문은 이 반복 구조를 Tensor Network contraction으로 묶어 한 번에 다루려 한다.
여기서 \#P-hard, NC, FPT, XP 같은 복잡도 class가 등장한다. 각 용어의 배경은 P와 NP 문제 글의 주변 복잡도 class 정리를 참고하면 된다.
이 논문은 단순히 “계산량이 줄었다”보다 더 복잡도 이론적인 질문을 던진다. 어떤 모델 class에서는 exact SHAP이 근본적으로 어렵고, 어떤 구조를 주면 병렬적으로 tractable해지는지를 구분한다.
3 Tensors, Tensor Networks and Binarized Neural Networks
Tensor는 여러 개의 index를 가진 배열이다. Vector가 1개의 index, matrix가 2개의 index를 가진다면, tensor는 그 이상의 index를 가진다고 보면 된다.
T(i_1,i_2,\ldots,i_n)
Tensor Network는 이런 tensor들을 graph로 연결한 표현이다. 각 node는 작은 tensor이고, edge는 두 tensor가 공유하는 index이다. 공유 index를 따라 두 tensor를 곱한 뒤 그 index에 대해 합산해서 하나의 tensor로 합치는 연산을 contraction이라고 부른다. 일상어의 “축소”라기보다, tensor 사이의 공통 축을 계산으로 소거하는 연산에 가깝다. 더 구체적인 계산 예시는 Tensor-Train Decomposition 글의 multidimensional contraction 부분을 참고하면 된다. Matrix multiplication도 contraction의 한 예다.
예를 들어 두 matrix A(i,k)와 B(k,j)를 곱하면 다음과 같다.
C(i,j) = \sum_k A(i,k)B(k,j)
여기서 k가 contraction되는 index이다. Tensor Network에서는 이런 contraction을 더 큰 graph 위에서 수행한다. 문제는 graph 구조가 일반적이면 contraction 순서를 잘 잡기 어렵고, 계산량도 커질 수 있다는 점이다.
Tensor Train은 Tensor Network 중에서도 chain 구조를 가진다. 고차원 tensor T(i_1,\ldots,i_n)을 다음처럼 core들의 곱으로 표현한다.
T(i_1,\ldots,i_n) = G_1(i_1)G_2(i_2)\cdots G_n(i_n)
각 G_k(i_k)는 scalar가 아니라 matrix이고, 양 끝은 row/column vector처럼 동작한다. 이 구조는 저장공간을 크게 줄일 수 있다. 하지만 이 논문에서 더 중요한 점은 저장공간보다 계산 graph가 chain이라는 사실이다. Chain 구조는 contraction을 순차적으로 혹은 병렬 prefix 방식으로 정리하기 좋다. 이 때문에 TT가 exact SHAP 계산의 tractable subclass가 된다.
논문은 또한 Binarized Neural Network를 다룬다. BNN은 weight나 activation이 이진값을 갖는 neural network이다. 이런 모델은 일반 neural network보다 단순해 보이지만, exact SHAP 계산 관점에서는 여전히 어렵다. 논문은 BNN을 Tensor Network로 바꾸어 분석하면서, width와 depth가 SHAP 계산 난이도에 어떻게 다르게 작용하는지 분리한다.
4 Provably Exact SHAP Explanations for TNs: A General Framework
논문은 먼저 일반 Tensor Network에 대해 exact Marginal SHAP을 계산하는 framework를 만든다. 핵심 객체는 Marginal SHAP Tensor이다. 보통 SHAP value는 feature별 scalar 값 \phi_i로 생각하지만, 논문은 이를 tensor 형태로 모은다. 이렇게 하면 모든 feature의 SHAP value를 하나의 contraction 결과로 표현할 수 있다.
논문이 하는 일은 크게 세 단계다.
- SHAP weight와 coalition 구조를 담는 tensor를 만든다.
- 모델 M과 distribution P로부터 marginal value tensor를 만든다.
- 두 tensor를 contraction해서 Marginal SHAP Tensor를 얻는다.
직관적으로 보면 SHAP 공식의 두 부분을 분리한 것이다. 하나는 Shapley weight처럼 모델과 무관한 조합론적 부분이고, 다른 하나는 v(S)처럼 모델과 distribution에 의존하는 부분이다. 논문은 이 둘을 tensor로 표현해 contraction으로 합친다.
이 framework는 중요하다. 왜냐하면 SHAP 계산을 “subset을 도는 알고리즘”에서 “Tensor Network를 구성하고 contraction하는 문제”로 바꾸기 때문이다. 이렇게 바꾸면 모델 class별로 같은 질문을 던질 수 있다.
이 모델과 distribution을 어떤 Tensor Network로 표현할 수 있는가? 그리고 그 network의 contraction은 tractable한가?
하지만 일반 TN에서는 좋은 소식만 있는 것은 아니다. 논문은 일반 Tensor Network에 대한 Marginal SHAP 계산이 #P-hard임을 보인다. 즉 Tensor Network라고 해서 자동으로 SHAP 계산이 쉬워지는 것은 아니다. 중요한 것은 Tensor Network라는 넓은 표현이 아니라, 그 안의 topology이다. 이후 TT 구조가 핵심으로 등장하는 이유가 여기에 있다.
5 Provably Exact and Tractable SHAP for TTs, and Other ML Models
5.1 Provably exact and tractable SHAP explanations for TTs
Tensor Train에서는 상황이 달라진다. 모델 M도 TT로 표현되고, background distribution P도 TT로 표현된다고 하자. 그러면 논문은 Marginal SHAP Tensor 자체도 TT로 표현할 수 있음을 보인다. 이 말은 단순히 결과가 compact하다는 뜻이 아니다. SHAP 계산에 필요한 전체 과정을 TT core들의 조합과 contraction으로 정리할 수 있다는 뜻이다.
논문에서 등장하는 구성 요소는 대략 다음과 같다.
- 모델의 TT core
- distribution의 TT core
- coalition을 제어하는 router tensor
- Shapley weight를 담는 weighted coalitional tensor
Router tensor는 feature가 coalition 안에 들어 있는지에 따라, 해당 feature를 설명 대상 입력 x_i로 고정할지 아니면 distribution에서 marginalize할지를 결정한다. 즉 SHAP의 “feature가 알려졌을 때와 알려지지 않았을 때”를 tensor graph 내부의 routing으로 구현한다.
이 구성이 중요한 이유는 모든 coalition S를 명시적으로 나열하지 않아도 된다는 점이다. Coalition 상태는 TT core를 통과하며 누적되고, Shapley weight도 별도의 TT 구조로 함께 들어간다. 결국 exponential한 subset sum이 TT contraction 안에 압축된다.
논문의 주요 결과는 Marginal SHAP for TT가 NC^2에 속한다는 것이다. NC^2는 poly-logarithmic parallel time과 polynomial processor로 계산 가능하다는 뜻이다. 여기서 “poly-logarithmic”은 입력 크기 n에 대해 (\log n)^k 꼴의 시간 복잡도를 말한다. 따라서 이 결과는 TT 기반 SHAP 계산이 병렬화 관점에서 매우 좋은 구조를 가진다는 의미다.
이 지점이 Introduction에서 말한 XAI와 TT의 연결을 가장 잘 보여준다. TT는 단지 parameter 수를 줄이는 format이 아니다. TT 구조는 exact SHAP의 combinatorial sum을 병렬 contraction 문제로 바꿀 수 있는 계산 구조를 제공한다.
5.2 Tightening complexity results for other ML models via TTs
논문은 TT 결과를 다른 모델 class로 확장한다. 핵심은 많은 모델을 TT로 표현하거나 TT로 환원할 수 있다는 점이다. Decision tree, tree ensemble, linear model, linear RNN은 적절한 construction을 통해 TT 구조 안에 넣을 수 있다.
이 결과는 다음처럼 이해하면 좋다. TreeSHAP은 tree라는 특수 구조를 이용해 SHAP을 빠르게 계산한다. 이 논문은 그보다 더 추상적인 층에서 “tree도 TT로 볼 수 있고, TT의 SHAP 계산은 병렬적으로 tractable하다”고 말한다. 즉 TreeSHAP류 알고리즘을 대체한다기보다, 왜 여러 모델 class에서 exact SHAP이 가능해지는지를 Tensor Network라는 공통 언어로 설명한다.
특히 distribution P도 TT로 표현할 수 있다는 점이 중요하다. 일반적인 SHAP 논의에서는 feature independence를 가정하거나, background sample을 단순하게 다루는 경우가 많다. 하지만 TT distribution은 feature dependency를 더 구조적으로 담을 수 있다. 따라서 이 framework는 모델뿐 아니라 data distribution의 표현력도 함께 고려한다.
정리하면 이 section의 메시지는 다음과 같다.
- TT model + TT distribution이면 Marginal SHAP이 병렬적으로 tractable하다.
- 여러 기존 model class는 TT 표현으로 옮겨갈 수 있다.
- 따라서 TT 결과는 tree, ensemble, linear model, linear RNN의 SHAP 복잡도 결과를 하나의 틀에서 재해석한다.
6 Fine-Grained Analysis of SHAP Computation for BNNs
논문 후반부는 Binarized Neural Network에 대한 complexity 분석이다. BNN은 모든 값이 이진적이라 일반 neural network보다 훨씬 단순해 보인다. 하지만 exact SHAP 계산 문제는 여전히 섬세하다.
논문은 depth와 width를 따로 본다. Depth는 layer 수이고, width는 각 layer의 neuron 수로 생각하면 된다. 일반적으로 neural network 복잡도를 말할 때 depth를 많이 떠올리지만, 이 논문의 SHAP 계산 결과는 width 쪽을 더 강하게 지목한다.
주요 결론은 세 가지다.
첫째, depth를 고정해도 SHAP 계산은 여전히 어렵다. 즉 layer 수가 작다고 해서 exact SHAP이 자동으로 쉬워지는 것은 아니다.
둘째, width를 고정하면 문제는 XP에 들어간다. 이는 width가 작은 상수라면 입력 크기에 대해 polynomial time 계산이 가능하다는 뜻이다. 다만 polynomial의 차수가 width에 의존할 수 있으므로, 실용적으로 항상 빠르다는 뜻은 아니다.
셋째, width와 sparsity를 함께 parameter로 두면 FPT 결과를 얻는다. Sparsity는 network 연결이 얼마나 희소한지를 나타낸다. Width가 작고 연결도 sparse하면 exact SHAP 계산이 더 강한 의미에서 tractable해진다.
이 분석은 XAI 관점에서 꽤 의미가 있다. 모델의 설명 가능성을 논할 때 단순히 “neural network라서 어렵다”라고 말하는 대신, 어떤 구조적 parameter가 explanation 계산을 어렵게 만드는지 분리해 보여주기 때문이다. 이 논문에서는 BNN의 경우 width가 중요한 bottleneck으로 드러난다.
7 Limitations and Future Work
논문의 장점은 exact SHAP 계산을 Tensor Network와 복잡도 이론의 언어로 정리했다는 점이다. 하지만 이 결과를 바로 실무 library 속도 개선으로 읽으면 안 된다. 주요 결과는 algorithm engineering보다 theoretical tractability에 가깝다. 특히 NC^2 결과는 병렬 계산 모델에서의 좋은 구조를 말하지만, 실제 GPU/CPU 구현에서 어느 정도 성능이 나오는지는 별도의 문제다.
논문 밖 해석으로 남겨둘 limitation. TT로 표현된다는 사실만으로 효율성이 자동으로 보장되지는 않는다. 고정된 tensor와 mode ordering에 대해 exact TT-rank는 unfolding rank로 결정되며, 이 rank가 크면 TT 표현도 compact하지 않다. 다만 approximation에서는 tolerance에 따라 낮은 rank TT를 선택할 수 있고, mode ordering에 따라서도 TT-rank가 달라질 수 있다. 따라서 이 논문의 tractability 결과는 모델과 distribution이 낮은 TT-rank를 갖거나, 충분히 낮은 rank로 근사 가능한 TT 표현을 가질 때 실질적인 의미가 커진다. 논문이 효율적인 rank나 mode ordering을 자동으로 탐색하는 방법까지 제시하는 것은 아니다.
Distribution을 TT로 표현하는 부분도 실무적으로는 중요한 질문을 남긴다. 복잡한 tabular dependency나 real-world feature distribution을 어떤 TT-rank로 잘 근사할 수 있는지, 그리고 그 근사가 SHAP value에 어떤 영향을 주는지는 후속 연구가 더 필요하다.
그래도 논문의 방향은 분명하다. XAI에서 exact explanation을 포기하지 않으려면, 모델 class의 구조를 더 적극적으로 사용해야 한다. Tensor Network는 그 구조를 표현하는 하나의 강력한 언어다.
8 Conclusion
이 논문은 SHAP 계산을 “feature subset을 전부 도는 문제”가 아니라 “모델과 distribution을 Tensor Network로 표현하고 contraction하는 문제”로 다시 쓴다. 일반 Tensor Network에서는 여전히 #P-hard하지만, Tensor Train으로 제한하면 exact Marginal SHAP이 NC^2에 들어간다. 이 결과는 TT가 XAI에서 중요한 이유를 잘 보여준다.
Tensor Train은 고차원 tensor를 압축하는 format일 뿐 아니라, SHAP의 exponential coalition sum을 구조화하는 계산 graph이다. 그래서 TT는 표현력과 tractability 사이의 중간 지점을 제공한다. 이 점이 논문 Introduction에서 강조하는 핵심이다. 설명하기 쉬운 단순 모델과 설명하기 어려운 black-box model 사이에, Tensor Network 기반 모델이라는 더 구조적인 선택지가 있다는 것이다.
이 논문은 tensor representation과 SHAP 계산 복잡도를 직접 연결한다. 핵심은 “구조화된 tensor representation이 exact explanation의 계산 복잡도 자체를 바꿀 수 있다”는 점이다.